from tokenDef import *
from jsAst.types import *

# Token values that Scanner.findType hands back for the constant keywords,
# keyed by their lower-cased spelling.
CONST_TOKEN_MAP = dict(
    true=jBool(True),
    false=jBool(False),
    null=jNull,
)

class Scanner:
    """Lexer that turns JavaScript-like source text into a stream of tokens.

    Call readToken() repeatedly; it returns None once the input is exhausted.
    Operators and punctuation come back as Node(op=...), every newline as
    Node(op=TOKEN['LF']), numbers as jInt/jFloat, string literals as jString,
    identifiers as jVar, and constant keywords as entries of CONST_TOKEN_MAP.
    Commas and comments are consumed silently and never returned.
    One token of look-back is supported through retToken().
    """

    # Blank characters skipped between tokens. Newlines are handled separately
    # because each one is reported as an LF token.
    _BLANKS = ' \t\r'

    # Escape sequences recognised inside string literals:
    # character following the backslash -> resulting character.
    _ESCAPES = {
        'n': '\n',
        't': '\t',
        'r': '\r',
        '"': '"',
        "'": "'",
        '\\': '\\',
    }

    def __init__(self, content) -> None:
        self._dealContent(content)
        self.offset = 0  # index of the next unread character in self.content
        self.size = len(self.content)
        self.line = 0  # number of newlines consumed so far
        # Dispatch table keyed by the first character of a token. Each handler
        # receives up to two characters of lookahead and advances self.offset.
        self.mapping = {
            TOKEN['ADD']: self.ADD,
            TOKEN['SUB']: self.SUB,
            TOKEN['MUL']: self.MUL,
            TOKEN['DIV']: self.DIV,
            TOKEN['QST']: self.Default,
            TOKEN['COL']: self.Default,
            TOKEN['GT']: self.GT,
            TOKEN['LT']: self.LT,
            TOKEN['EQ']: self.EQ,
            TOKEN['NOT']: self.NOT,
            '&': self.SameTwo,
            '|': self.SameTwo,
            TOKEN['SEMI']: self.Default,
            TOKEN['COMMA']: self.COMMA,
            TOKEN['ASSIGN']: self.Default,
            TOKEN['LPT']: self.Default,
            TOKEN['RPT']: self.Default,
            TOKEN['LBR']: self.Default,
            TOKEN['RBR']: self.Default,
            TOKEN['LMBR']: self.Default,
            TOKEN['RMBR']: self.Default,
            '.': self.DOT,
            '"': self.String,
            "'": self.String,
        }
        self.cur = None  # token pushed back by retToken(), if any

    def readToken(self):
        """Return the next token, or None at end of input.

        A token previously handed back through retToken() is returned first.
        """
        if self.cur is not None:
            token, self.cur = self.cur, None
            return token
        return self._readToken()

    def retToken(self, token):
        """Push *token* back so that the next readToken() returns it again.

        Only one token is kept; a second call overwrites the first.
        """
        self.cur = token

    def _readToken(self):
        """Scan and return the next token from the raw text (ignores pushback)."""
        # Skip blanks; every newline is counted and reported as its own LF token.
        while self.offset < self.size:
            ch = self.content[self.offset]
            if ch == '\n':
                self.offset += 1
                self.line += 1
                return Node(op=TOKEN['LF'])
            if ch not in self._BLANKS:
                break
            self.offset += 1
        if self.offset >= self.size:
            return None

        word = self.content[self.offset:self.offset + 2]
        handler = self.mapping.get(word[0])
        if handler is not None:
            return handler(word)

        # Identifier, keyword or number: runs up to the next delimiter.
        start = self.offset
        end = self._scanWord(start)
        # A number immediately followed by '.' continues as a float literal
        # (e.g. "3.14", "3."), including when it sits at the very end of input.
        if (self.content[start].isdigit() and end < self.size
                and self.content[end] == '.'):
            end = self._scanWord(end + 1)
            self.offset = end
            return jFloat(self.content[start:end])
        return self.findType(self.content[start:end])

    def _isDelimiter(self, ch):
        """Return True if *ch* ends an identifier or number."""
        return ch in self.mapping or ch in self._BLANKS or ch == '\n'

    def _scanWord(self, pos):
        """Return the index of the first delimiter at or after *pos* (or self.size)."""
        while pos < self.size and not self._isDelimiter(self.content[pos]):
            pos += 1
        return pos

    def findType(self, word):
        """Consume *word* and classify it as a number, operator, constant or identifier.

        Words starting with a digit become jInt. Words listed in OPT_TOKEN
        become Node(op=word). Words whose lower-cased form is in CONST_TOKEN are
        looked up in CONST_TOKEN_MAP. Anything else becomes jVar.
        """
        self.offset += len(word)
        if word[:1].isdigit():
            return jInt(word)
        if word in OPT_TOKEN:
            return Node(op=word)
        if word.lower() in CONST_TOKEN:
            return CONST_TOKEN_MAP[word.lower()]
        return jVar(word)

    def _dealContent(self, content):
        """Store *content* in self.content with each line stripped of surrounding whitespace.

        NOTE(review): this also strips whitespace at line boundaries inside
        string literals that span several lines.
        """
        self.content = '\n'.join(line.strip() for line in content.split('\n'))

    def _opWithOptionalEq(self, key):
        """Emit the one-character operator key[0], or key itself when '=' follows it."""
        if key[1:2] == '=':
            self.offset += 2
            return Node(op=key)
        self.offset += 1
        return Node(op=key[0])

    def ADD(self, key):
        """Handle '+' and '+='."""
        return self._opWithOptionalEq(key)

    def SUB(self, key):
        """Handle '-' and '-='."""
        return self._opWithOptionalEq(key)

    def MUL(self, key):
        """Handle '*', '*=' and '**'."""
        # Tuple membership: an empty lookahead ('' at EOF) must not match.
        if key[1:2] in ('=', '*'):
            self.offset += 2
            return Node(op=key)
        self.offset += 1
        return Node(op=key[0])

    def DIV(self, key):
        """Handle '/', '/=', and skip '//' line comments and '/* */' block comments.

        After a comment the following token is returned instead; None is
        returned when the comment runs to the end of input.
        """
        nxt = key[1:2]
        if nxt == '=':
            self.offset += 2
            return Node(op=key)
        if nxt == '/':
            # Stop *at* the newline so it is still counted and reported as LF.
            end = self.content.find('\n', self.offset)
            self.offset = self.size if end == -1 else end
            return self._readToken()
        if nxt == '*':
            # Search after the "/*" opener so that "/*/" is not treated as closed.
            end = self.content.find('*/', self.offset + 2)
            stop = self.size if end == -1 else end + 2
            self.line += self.content.count('\n', self.offset, stop)
            self.offset = stop
            return self._readToken()
        self.offset += 1
        return Node(op='/')

    def GT(self, key):
        """Handle '>' and '>='."""
        return self._opWithOptionalEq(key)

    def LT(self, key):
        """Handle '<' and '<='."""
        return self._opWithOptionalEq(key)

    def EQ(self, key):
        """Handle '=' and '=='."""
        return self._opWithOptionalEq(key)

    def NOT(self, key):
        """Handle '!' and '!='."""
        return self._opWithOptionalEq(key)

    def DOT(self, key):
        """Handle '.': member access, or a decimal with the leading 0 omitted (".5").

        Raises:
            Exception: if '.' is the last character of the input.
        """
        if len(key) == 1:
            raise Exception('Error: except oprator for "."')
        if key[1].isdigit():
            # Scan like the integer-part path does, so ".5" -> jFloat("0.5").
            end = self._scanWord(self.offset + 1)
            data = jFloat('0' + self.content[self.offset:end])
            self.offset = end
            return data
        self.offset += 1
        return Node(op=key[0])

    def Default(self, key):
        """Emit the single-character punctuation token key[0]."""
        self.offset += 1
        return Node(op=key[0])

    def SameTwo(self, key):
        """Handle the doubled operators '&&' and '||'; a lone '&' or '|' is rejected."""
        if len(key) != 2 or key[1] != key[0]:
            raise Exception('Error: '+key+' is not support')
        self.offset += 2
        return Node(op=key)

    def COMMA(self, _):
        """Skip a comma and return the token that follows it."""
        self.offset += 1
        return self._readToken()

    def String(self, key):
        """Scan a quoted string literal opened by key[0] and return it as jString.

        Recognised escapes are listed in _ESCAPES. An unknown escape keeps its
        backslash literally.

        Raises:
            Exception: if the input ends before the closing quote.
        """
        quote = key[0]
        chars = []
        closed = False
        self.offset += 1
        while self.offset < self.size:
            ch = self.content[self.offset]
            if ch == quote:
                self.offset += 1
                closed = True
                break
            if ch == '\\':
                if self.offset + 1 >= self.size:
                    break  # dangling backslash at end of input: unclosed
                escaped = self._ESCAPES.get(self.content[self.offset + 1])
                if escaped is not None:
                    chars.append(escaped)
                    self.offset += 2
                    continue
                # Unknown escape: keep the backslash; the next char is read normally.
            elif ch == '\n':
                self.line += 1
            chars.append(ch)
            self.offset += 1
        if not closed:
            raise Exception('Error: unclosed string')
        return jString(''.join(chars))

if __name__ == '__main__':
    # Debug driver: print every token of a JS file.
    # The path comes from the first CLI argument; it defaults to test.js.
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else 'test.js'
    # The context manager closes the file, which the original never did.
    with open(path, 'r', encoding='utf-8') as source_file:
        source = source_file.read()
    scanner = Scanner(source)
    while True:
        token = scanner.readToken()
        if token is None:  # end of input
            break
        print(token)